{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Copyright 2021 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ],
      "metadata": {
        "id": "fRC4mvBYeS1f"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Dimensions of Motion colab\n",
        "\n",
        "This brief notebook accompanies the code for the paper __Dimensions of Motion: Monocular Prediction through Flow Subspaces__. The code is available at\n",
        "https://github.com/google-research/google-research/tree/master/dimensions_of_motion.\n",
        "\n",
        "The project page is at https://dimensions-of-motion.github.io/.\n",
        "\n",
        "Choose __Run all__ from the Runtime menu to:\n",
        "* install required packages and download our code,\n",
        "* compute and visualize a flow basis and reconstruction losses from an example embedding and disparity,\n",
        "* show how to create a model and compute losses.\n",
        "\n",
        "Please feel free to contact the authors if you have questions.\n",
        "\n"
      ],
      "metadata": {
        "id": "96VW3NRneW03"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Install packages, download code and example."
      ],
      "metadata": {
        "id": "tGFwtHBZeHyJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%shell\n",
        "\n",
        "# Download code.\n",
        "# GitHub removed Subversion access in January 2024, so the former\n",
        "# `svn export .../trunk/dimensions_of_motion` no longer works. A shallow,\n",
        "# blob-less sparse git clone is used instead so that the rest of the (large)\n",
        "# google-research repository is not checked out.\n",
        "rm -rf google-research\n",
        "git clone --depth 1 --filter=blob:none --sparse https://github.com/google-research/google-research.git &&\n",
        "  git -C google-research sparse-checkout set dimensions_of_motion &&\n",
        "  rm -rf dimensions_of_motion &&\n",
        "  mv google-research/dimensions_of_motion dimensions_of_motion\n",
        "rm -rf google-research\n",
        "\n",
        "# Install required libraries\n",
        "pip install -r dimensions_of_motion/requirements.txt\n",
        "\n",
        "# A simple example.\n",
        "# `mkdir -p` and `wget -N` keep this cell safe to re-run: the directory may\n",
        "# already exist, and existing files are refreshed in place instead of being\n",
        "# saved again as `*.npy.1`.\n",
        "mkdir -p example\n",
        "cd example\n",
        "wget -N https://dimensions-of-motion.github.io/example/source_image.npy\n",
        "wget -N https://dimensions-of-motion.github.io/example/target_image.npy\n",
        "wget -N https://dimensions-of-motion.github.io/example/flow.npy\n",
        "wget -N https://dimensions-of-motion.github.io/example/disparity.npy\n",
        "wget -N https://dimensions-of-motion.github.io/example/embedding.npy\n"
      ],
      "metadata": {
        "id": "RETUiGB1dH_d"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Compute and visualize basis flow fields"
      ],
      "metadata": {
        "id": "csgxaiwQn7XK"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2UPQ8wdXbUpn"
      },
      "outputs": [],
      "source": [
        "# Make sure python can find our libraries\n",
        "import sys\n",
        "sys.path.append('dimensions_of_motion')\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_addons as tfa\n",
        "\n",
        "import palettable.matplotlib\n",
        "import flow_vis\n",
        "\n",
        "import loss\n",
        "\n",
        "\n",
        "def show_image(data):\n",
        "  \"\"\"Draws the first image of a batch with plt.imshow.\n",
        "\n",
        "  The size of the last axis selects the rendering:\n",
        "    3: shown directly.\n",
        "    2: treated as a flow map and converted with flow_vis.flow_to_color.\n",
        "    1: treated as disparity and shown with the Viridis colormap, vmin=0.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if the last axis has any other size.\n",
        "  \"\"\"\n",
        "  first = data[0]  # Only the first batch element is displayed.\n",
        "  channels = first.shape[-1]\n",
        "\n",
        "  if channels == 1:  # disparity\n",
        "    viridis = palettable.matplotlib.Viridis_20.mpl_colormap\n",
        "    plt.imshow(first[..., 0], cmap=viridis, vmin=0)\n",
        "    return\n",
        "  if channels == 2:  # flow map\n",
        "    plt.imshow(flow_vis.flow_to_color(first))\n",
        "    return\n",
        "  if channels == 3:\n",
        "    plt.imshow(first)\n",
        "    return\n",
        "  raise ValueError(f'Bad image data (channels={channels}).')\n",
        "\n",
        "# Default figure size (width, height) in inches for the plots below.\n",
        "plt.rcParams['figure.figsize'] = (20, 3)\n",
        "\n",
        "def show_images(names_and_images, images_per_row=4):\n",
        "  \"\"\"Shows titled image batches in rows of up to images_per_row plots.\n",
        "\n",
        "  Args:\n",
        "    names_and_images: sequence of (name, batch) entries; entry[0] is used as\n",
        "      the plot title and entry[1] is passed to show_image.\n",
        "    images_per_row: number of subplot columns in each figure row.\n",
        "  \"\"\"\n",
        "  for i, entry in enumerate(names_and_images):\n",
        "    column = i % images_per_row\n",
        "    # A full row has been drawn: flush it before starting the next one.\n",
        "    if column == 0 and i > 0:\n",
        "      plt.show()\n",
        "    plt.subplot(1, images_per_row, column + 1)\n",
        "    show_image(entry[1])\n",
        "    plt.axis('off')\n",
        "    plt.title(entry[0])\n",
        "  # Flush the final (possibly partial) row.\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def load(f):\n",
        "  \"\"\"Loads a numpy file and returns it as a tensor with a batch size of 1.\"\"\"\n",
        "  array = np.load(f)\n",
        "  return tf.convert_to_tensor(array)[tf.newaxis]\n",
        "\n",
        "# Example inputs; load() gives each one a leading batch axis of size 1.\n",
        "image = load('example/source_image.npy')\n",
        "target = load('example/target_image.npy')\n",
        "flow = load('example/flow.npy')\n",
        "\n",
        "# Scene representation (output from running model on input image)\n",
        "disparity = load('example/disparity.npy')\n",
        "embeddings = load('example/embedding.npy')\n",
        "\n",
        "losses, summaries = loss.compute_motion_loss(image, flow, disparity, embeddings)\n",
        "\n",
        "show_images([\n",
        "    ['input', image],\n",
        "    ['target', target],\n",
        "    ['flow', flow],\n",
        "])\n",
        "\n",
        "# Each losses entry is a (value, weight) pair; the braces below hold a tuple,\n",
        "# so every line is printed as `key = (value, weight)`.\n",
        "for key, (value, weight) in losses.items():\n",
        "  print(f'{key} = {value[0].numpy(), weight}')\n",
        "\n",
        "show_images(list(summaries.items()))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Build example model and compute losses\n",
        "This code block shows how to build a model with 6 embedding dimensions, run\n",
        "it on an example batch of one image, and use ground-truth flow to compute\n",
        "losses.\n"
      ],
      "metadata": {
        "id": "mGfsOaJ-mnXo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import models\n",
        "\n",
        "# Placeholder shapes in model() are [BATCH_SIZE, HEIGHT, WIDTH, channels].\n",
        "BATCH_SIZE = 1\n",
        "HEIGHT = 256\n",
        "WIDTH = 256\n",
        "# Passed to models.run_model as the embedding dimension.\n",
        "EMBEDDING_DIMENSION = 6\n",
        "\n",
        "def model():\n",
        "  \"\"\"Builds the example model in a TF1-style graph and returns a runner.\n",
        "\n",
        "  Creates a tf.compat.v1.Session, adds float32 image and flow placeholders of\n",
        "  shape [BATCH_SIZE, HEIGHT, WIDTH, 3] and [BATCH_SIZE, HEIGHT, WIDTH, 2] to\n",
        "  its graph, calls models.run_model on them and initializes global variables.\n",
        "\n",
        "  Returns:\n",
        "    A function run_model(input_image, input_flow) that feeds both inputs and\n",
        "    returns the result of session.run on [losses, summaries]. The session\n",
        "    stays open for as long as the returned function is in use.\n",
        "  \"\"\"\n",
        "  session = tf.compat.v1.Session()\n",
        "  with session.graph.as_default():\n",
        "    image_in = tf.compat.v1.placeholder(\n",
        "        tf.float32, [BATCH_SIZE, HEIGHT, WIDTH, 3])\n",
        "    flow_in = tf.compat.v1.placeholder(\n",
        "        tf.float32, [BATCH_SIZE, HEIGHT, WIDTH, 2])\n",
        "    losses, summaries = models.run_model(\n",
        "        image_in, flow_in, EMBEDDING_DIMENSION, None)\n",
        "    session.run(tf.compat.v1.global_variables_initializer())\n",
        "  fetches = [losses, summaries]\n",
        "\n",
        "  def as_numpy(x):\n",
        "    # feed_dict values are passed as numpy data, so tensors are converted.\n",
        "    return x.numpy() if tf.is_tensor(x) else x\n",
        "\n",
        "  def run_model(input_image, input_flow):\n",
        "    feed = {\n",
        "        image_in: as_numpy(input_image),\n",
        "        flow_in: as_numpy(input_flow),\n",
        "    }\n",
        "    return session.run(fetches, feed_dict=feed)\n",
        "\n",
        "  return run_model\n",
        "\n",
        "# Build the graph once; run_model can then be called on any matching batch.\n",
        "run_model = model()\n",
        "\n",
        "print('Running untrained model on input image')\n",
        "show_images([['image', image]])\n",
        "\n",
        "losses, summaries = run_model(image, flow)\n",
        "\n",
        "print('Losses')\n",
        "for loss_name, loss_value in losses.items():\n",
        "  print('   ', loss_name, loss_value)\n",
        "\n",
        "print('\\nSummary images')\n",
        "summary_items = list(summaries.items())\n",
        "show_images(summary_items)"
      ],
      "metadata": {
        "id": "wpdduU_9OW66"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}